/*
 * Copyright 2012 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.channel.socket.nio;

import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelException;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelOutboundBuffer;
import io.netty.channel.ChannelPromise;
import io.netty.channel.EventLoop;
import io.netty.channel.FileRegion;
import io.netty.channel.RecvByteBufAllocator;
import io.netty.util.internal.SocketUtils;
import io.netty.channel.nio.AbstractNioByteChannel;
import io.netty.channel.socket.DefaultSocketChannelConfig;
import io.netty.channel.socket.ServerSocketChannel;
import io.netty.channel.socket.SocketChannelConfig;
import io.netty.util.concurrent.GlobalEventExecutor;
import io.netty.util.internal.PlatformDependent;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.net.SocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.SelectionKey;
import java.nio.channels.SocketChannel;
import java.nio.channels.spi.SelectorProvider;
import java.util.concurrent.Executor;

/**
 * {@link io.netty.channel.socket.SocketChannel} which uses NIO selector based
 * implementation.
 */
public class NioSocketChannel extends AbstractNioByteChannel
        implements io.netty.channel.socket.SocketChannel
{
    private static final InternalLogger logger = InternalLoggerFactory
            .getInstance(NioSocketChannel.class);

    private static final SelectorProvider DEFAULT_SELECTOR_PROVIDER = SelectorProvider
            .provider();

    private static SocketChannel newSocket(SelectorProvider provider)
    {
        try
        {
            /**
             * Use the {@link SelectorProvider} to open {@link SocketChannel}
             * and so remove condition in {@link SelectorProvider#provider()}
             * which is called by each SocketChannel.open() otherwise.
             *
             * See
             * <a href="https://github.com/netty/netty/issues/2308">#2308</a>.
             */
            return provider.openSocketChannel();
        }
        catch (IOException e)
        {
            throw new ChannelException("Failed to open a socket.", e);
        }
    }

    private final SocketChannelConfig config;

    /**
     * Create a new instance
     */
    public NioSocketChannel()
    {
        this(DEFAULT_SELECTOR_PROVIDER);
    }

    /**
     * Create a new instance using the given {@link SelectorProvider}.
     */
    public NioSocketChannel(SelectorProvider provider)
    {
        this(newSocket(provider));
    }

    /**
     * Create a new instance using the given {@link SocketChannel}.
     */
    public NioSocketChannel(SocketChannel socket)
    {
        this(null, socket);
    }

    /**
     * Create a new instance
     *
     * @param parent the {@link Channel} which created this instance or
     *        {@code null} if it was created by the user
     * @param socket the {@link SocketChannel} which will be used
     */
    public NioSocketChannel(Channel parent, SocketChannel socket)
    {
        super(parent, socket);
        config = new NioSocketChannelConfig(this, socket.socket());
    }

    @Override
    public ServerSocketChannel parent()
    {
        return (ServerSocketChannel) super.parent();
    }

    @Override
    public SocketChannelConfig config()
    {
        return config;
    }

    @Override
    protected SocketChannel javaChannel()
    {
        return (SocketChannel) super.javaChannel();
    }

    @Override
    public boolean isActive()
    {
        SocketChannel ch = javaChannel();
        return ch.isOpen() && ch.isConnected();
    }

    @Override
    public boolean isOutputShutdown()
    {
        return javaChannel().socket().isOutputShutdown() || !isActive();
    }

    @Override
    public boolean isInputShutdown()
    {
        return javaChannel().socket().isInputShutdown() || !isActive();
    }

    @Override
    public boolean isShutdown()
    {
        Socket socket = javaChannel().socket();
        return socket.isInputShutdown() && socket.isOutputShutdown()
                || !isActive();
    }

    @Override
    public InetSocketAddress localAddress()
    {
        return (InetSocketAddress) super.localAddress();
    }

    @Override
    public InetSocketAddress remoteAddress()
    {
        return (InetSocketAddress) super.remoteAddress();
    }

    @Override
    public ChannelFuture shutdownOutput()
    {
        return shutdownOutput(newPromise());
    }

    @Override
    public ChannelFuture shutdownOutput(final ChannelPromise promise)
    {
        Executor closeExecutor = ((NioSocketChannelUnsafe) unsafe())
                .prepareToClose();
        if (closeExecutor != null)
        {
            closeExecutor.execute(new Runnable()
            {
                @Override
                public void run()
                {
                    shutdownOutput0(promise);
                }
            });
        }
        else
        {
            EventLoop loop = eventLoop();
            if (loop.inEventLoop())
            {
                shutdownOutput0(promise);
            }
            else
            {
                loop.execute(new Runnable()
                {
                    @Override
                    public void run()
                    {
                        shutdownOutput0(promise);
                    }
                });
            }
        }
        return promise;
    }

    @Override
    public ChannelFuture shutdownInput()
    {
        return shutdownInput(newPromise());
    }

    @Override
    protected boolean isInputShutdown0()
    {
        return isInputShutdown();
    }

    @Override
    public ChannelFuture shutdownInput(final ChannelPromise promise)
    {
        Executor closeExecutor = ((NioSocketChannelUnsafe) unsafe())
                .prepareToClose();
        if (closeExecutor != null)
        {
            closeExecutor.execute(new Runnable()
            {
                @Override
                public void run()
                {
                    shutdownInput0(promise);
                }
            });
        }
        else
        {
            EventLoop loop = eventLoop();
            if (loop.inEventLoop())
            {
                shutdownInput0(promise);
            }
            else
            {
                loop.execute(new Runnable()
                {
                    @Override
                    public void run()
                    {
                        shutdownInput0(promise);
                    }
                });
            }
        }
        return promise;
    }

    @Override
    public ChannelFuture shutdown()
    {
        return shutdown(newPromise());
    }

    @Override
    public ChannelFuture shutdown(final ChannelPromise promise)
    {
        Executor closeExecutor = ((NioSocketChannelUnsafe) unsafe())
                .prepareToClose();
        if (closeExecutor != null)
        {
            closeExecutor.execute(new Runnable()
            {
                @Override
                public void run()
                {
                    shutdown0(promise);
                }
            });
        }
        else
        {
            EventLoop loop = eventLoop();
            if (loop.inEventLoop())
            {
                shutdown0(promise);
            }
            else
            {
                loop.execute(new Runnable()
                {
                    @Override
                    public void run()
                    {
                        shutdown0(promise);
                    }
                });
            }
        }
        return promise;
    }

    private void shutdownOutput0(final ChannelPromise promise)
    {
        try
        {
            shutdownOutput0();
            promise.setSuccess();
        }
        catch (Throwable t)
        {
            promise.setFailure(t);
        }
    }

    private void shutdownOutput0() throws Exception
    {
        if (PlatformDependent.javaVersion() >= 7)
        {
            javaChannel().shutdownOutput();
        }
        else
        {
            javaChannel().socket().shutdownOutput();
        }
    }

    private void shutdownInput0(final ChannelPromise promise)
    {
        try
        {
            shutdownInput0();
            promise.setSuccess();
        }
        catch (Throwable t)
        {
            promise.setFailure(t);
        }
    }

    private void shutdownInput0() throws Exception
    {
        if (PlatformDependent.javaVersion() >= 7)
        {
            javaChannel().shutdownInput();
        }
        else
        {
            javaChannel().socket().shutdownInput();
        }
    }

    private void shutdown0(final ChannelPromise promise)
    {
        Throwable cause = null;
        try
        {
            shutdownOutput0();
        }
        catch (Throwable t)
        {
            cause = t;
        }
        try
        {
            shutdownInput0();
        }
        catch (Throwable t)
        {
            if (cause == null)
            {
                promise.setFailure(t);
            }
            else
            {
                logger.debug(
                        "Exception suppressed because a previous exception occurred.",
                        t);
                promise.setFailure(cause);
            }
            return;
        }
        if (cause == null)
        {
            promise.setSuccess();
        }
        else
        {
            promise.setFailure(cause);
        }
    }

    @Override
    protected SocketAddress localAddress0()
    {
        return javaChannel().socket().getLocalSocketAddress();
    }

    @Override
    protected SocketAddress remoteAddress0()
    {
        return javaChannel().socket().getRemoteSocketAddress();
    }

    @Override
    protected void doBind(SocketAddress localAddress) throws Exception
    {
        doBind0(localAddress);
    }

    private void doBind0(SocketAddress localAddress) throws Exception
    {
        if (PlatformDependent.javaVersion() >= 7)
        {
            SocketUtils.bind(javaChannel(), localAddress);
        }
        else
        {
            SocketUtils.bind(javaChannel().socket(), localAddress);
        }
    }

    @Override
    protected boolean doConnect(SocketAddress remoteAddress,
            SocketAddress localAddress) throws Exception
    {
        if (localAddress != null)
        {
            doBind0(localAddress);
        }

        boolean success = false;
        try
        {
            boolean connected = SocketUtils.connect(javaChannel(),
                    remoteAddress);
            if (!connected)
            {
                selectionKey().interestOps(SelectionKey.OP_CONNECT);
            }
            success = true;
            return connected;
        }
        finally
        {
            if (!success)
            {
                doClose();
            }
        }
    }

    @Override
    protected void doFinishConnect() throws Exception
    {
        if (!javaChannel().finishConnect())
        {
            throw new Error();
        }
    }

    @Override
    protected void doDisconnect() throws Exception
    {
        doClose();
    }

    @Override
    protected void doClose() throws Exception
    {
        super.doClose();
        javaChannel().close();
    }

    @Override
    protected int doReadBytes(ByteBuf byteBuf) throws Exception
    {
        final RecvByteBufAllocator.Handle allocHandle = unsafe()
                .recvBufAllocHandle();
        allocHandle.attemptedBytesRead(byteBuf.writableBytes());
        return byteBuf.writeBytes(javaChannel(),
                allocHandle.attemptedBytesRead());
    }

    @Override
    protected int doWriteBytes(ByteBuf buf) throws Exception
    {
        final int expectedWrittenBytes = buf.readableBytes();
        return buf.readBytes(javaChannel(), expectedWrittenBytes);
    }

    @Override
    protected long doWriteFileRegion(FileRegion region) throws Exception
    {
        final long position = region.transferred();
        return region.transferTo(javaChannel(), position);
    }

    @Override
    protected void doWrite(ChannelOutboundBuffer in) throws Exception
    {
        for (;;)
        {
            int size = in.size();
            if (size == 0)
            {
                // All written so clear OP_WRITE
                clearOpWrite();
                break;
            }
            long writtenBytes = 0;
            boolean done = false;
            boolean setOpWrite = false;

            // Ensure the pending writes are made of ByteBufs only.
            ByteBuffer[] nioBuffers = in.nioBuffers();
            int nioBufferCnt = in.nioBufferCount();
            long expectedWrittenBytes = in.nioBufferSize();
            SocketChannel ch = javaChannel();

            // Always us nioBuffers() to workaround data-corruption.
            // See https://github.com/netty/netty/issues/2761
            switch (nioBufferCnt)
            {
                case 0:
                    // We have something else beside ByteBuffers to write so
                    // fallback to normal writes.
                    super.doWrite(in);
                    return;
                case 1:
                    // Only one ByteBuf so use non-gathering write
                    ByteBuffer nioBuffer = nioBuffers[0];
                    for (int i = config().getWriteSpinCount() - 1; i >= 0; i--)
                    {
                        final int localWrittenBytes = ch.write(nioBuffer);
                        if (localWrittenBytes == 0)
                        {
                            setOpWrite = true;
                            break;
                        }
                        expectedWrittenBytes -= localWrittenBytes;
                        writtenBytes += localWrittenBytes;
                        if (expectedWrittenBytes == 0)
                        {
                            done = true;
                            break;
                        }
                    }
                    break;
                default:
                    for (int i = config().getWriteSpinCount() - 1; i >= 0; i--)
                    {
                        final long localWrittenBytes = ch.write(nioBuffers, 0,
                                nioBufferCnt);
                        if (localWrittenBytes == 0)
                        {
                            setOpWrite = true;
                            break;
                        }
                        expectedWrittenBytes -= localWrittenBytes;
                        writtenBytes += localWrittenBytes;
                        if (expectedWrittenBytes == 0)
                        {
                            done = true;
                            break;
                        }
                    }
                    break;
            }

            // Release the fully written buffers, and update the indexes of the
            // partially written buffer.
            in.removeBytes(writtenBytes);

            if (!done)
            {
                // Did not write all buffers completely.
                incompleteWrite(setOpWrite);
                break;
            }
        }
    }

    @Override
    protected AbstractNioUnsafe newUnsafe()
    {
        return new NioSocketChannelUnsafe();
    }

    private final class NioSocketChannelUnsafe extends NioByteUnsafe
    {
        @Override
        protected Executor prepareToClose()
        {
            try
            {
                if (javaChannel().isOpen() && config().getSoLinger() > 0)
                {
                    // We need to cancel this key of the channel so we may not
                    // end up in a eventloop spin
                    // because we try to read or write until the actual close
                    // happens which may be later due
                    // SO_LINGER handling.
                    // See https://github.com/netty/netty/issues/4449
                    doDeregister();
                    return GlobalEventExecutor.INSTANCE;
                }
            }
            catch (Throwable ignore)
            {
                // Ignore the error as the underlying channel may be closed in
                // the meantime and so
                // getSoLinger() may produce an exception. In this case we just
                // return null.
                // See https://github.com/netty/netty/issues/4449
            }
            return null;
        }
    }

    private final class NioSocketChannelConfig
            extends DefaultSocketChannelConfig
    {
        private NioSocketChannelConfig(NioSocketChannel channel,
                Socket javaSocket)
        {
            super(channel, javaSocket);
        }

        @Override
        protected void autoReadCleared()
        {
            clearReadPending();
        }
    }
}
